import paddle

class SSHead(paddle.nn.Layer):
    """
    Detection head that runs three 1x1 convolutions on each input feature map.

    For every entry of ``in_channels`` the head owns a location conv
    (4 output channels), an objectness conv (1 output channel) and a class
    conv (``num_classes`` output channels). ``forward`` concatenates the
    three results along axis 1.
    """

    def __init__(self, num_classes=20, in_channels=(256, 256, 256)):
        """
        Initialize the detection head.

        params:
        - num_classes (int): number of object classes
        - in_channels (list | tuple): channel count of each input feature map;
          one group of convolutions is created per entry
        """
        super().__init__()

        self.num_head = len(in_channels)  # number of output heads
        self.loc_list = []                # location convolutions, one per head
        self.obj_list = []                # objectness convolutions, one per head
        self.cls_list = []                # class convolutions, one per head

        # Creation order (loc, obj, cls for each head) and the sublayer names
        # are kept stable so parameter names in saved state dicts do not change.
        for i, channels in enumerate(in_channels):
            self.loc_list.append(self._add_branch('loc_' + str(i), channels, 4))
            self.obj_list.append(self._add_branch('obj_' + str(i), channels, 1))
            self.cls_list.append(self._add_branch('cls_' + str(i), channels, num_classes))

    def _add_branch(self, name, in_channels, out_channels):
        """Create a 1x1 Conv2D, register it under ``name`` and return it."""
        conv = paddle.nn.Conv2D(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=1, stride=1, padding=0,
            # Weight and bias each get their own KaimingNormal initializer instance.
            weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal()),
            bias_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal())
        )
        # The layers are also kept in plain Python lists, which do not register
        # parameters by themselves; add_sublayer is what makes them part of
        # this Layer.
        return self.add_sublayer(name, conv)

    def forward(self, c_list):
        """
        Apply every head to its matching input feature map.

        params:
        - c_list (list): input feature maps; entry ``i`` is fed to head ``i``.
          Only the first ``num_head`` entries are read, and a shorter list
          raises IndexError.

        returns:
        - list: one tensor per head, the concatenation along axis 1 of the
          location, objectness and class outputs (in that order)
        """
        p_list = []                         # output feature list
        for i in range(self.num_head):      # walk over the heads
            feature = c_list[i]
            loc_item = self.loc_list[i](feature)  # location features
            obj_item = self.obj_list[i](feature)  # objectness features
            cls_item = self.cls_list[i](feature)  # class features
            p_list.append(paddle.concat([loc_item, obj_item, cls_item], axis=1))

        return p_list